R Markdown

Tensor product splines using only 2 covariates.

library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(ggplot2)
library(mgcv)
## Loading required package: nlme
## 
## Attaching package: 'nlme'
## The following object is masked from 'package:dplyr':
## 
##     collapse
## This is mgcv 1.8-31. For overview type 'help("mgcv-package")'.
library(purrr)
### number of intervals per dimension in the fine grid; each axis therefore
### has num_fine_int + 1 equally spaced points
num_fine_int <- 40

### both covariates share the same fine axis, spanning -2 to 2
fine_grid_list <- setNames(
  rep(list(seq(from = -2, to = 2, length.out = num_fine_int + 1)), times = 2),
  c("x1", "x2")
)
fine_grid_list
## $x1
##  [1] -2.0 -1.9 -1.8 -1.7 -1.6 -1.5 -1.4 -1.3 -1.2 -1.1 -1.0 -0.9 -0.8 -0.7 -0.6
## [16] -0.5 -0.4 -0.3 -0.2 -0.1  0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  0.8  0.9
## [31]  1.0  1.1  1.2  1.3  1.4  1.5  1.6  1.7  1.8  1.9  2.0
## 
## $x2
##  [1] -2.0 -1.9 -1.8 -1.7 -1.6 -1.5 -1.4 -1.3 -1.2 -1.1 -1.0 -0.9 -0.8 -0.7 -0.6
## [16] -0.5 -0.4 -0.3 -0.2 -0.1  0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  0.8  0.9
## [31]  1.0  1.1  1.2  1.3  1.4  1.5  1.6  1.7  1.8  1.9  2.0
### true univariate relationship for each factor

### functional form: a0 + a1 * cos(a2 * pi * x), with the coefficients
### read from the list `av`
true_functions <- list(
  g1 = function(x, av) {
    phase <- av$a2 * pi * x
    av$a0 + av$a1 * cos(phase)
  },
  g2 = function(x, av) {
    phase <- av$a2 * pi * x
    av$a0 + av$a1 * cos(phase)
  }
)

### coefficients (a0 = offset, a1 = amplitude, a2 = frequency multiplier)
### supplied to each function above
true_hypers <- list(
  g1 = list(a0 = 0, a1 = 1, a2 = 1),
  g2 = list(a0 = 0, a1 = 1, a2 = 1)
)
### wrapper that evaluates one factor function
###   myfunc   - function taking (x, params)
###   myx      - input values passed as the first argument
###   myparams - parameter list passed as the second argument
### returns whatever `myfunc` returns
run_factors <- function(myfunc, myx, myparams) {
  factor_values <- myfunc(myx, myparams)
  factor_values
}
### evaluate every factor over its fine grid: the i-th function is called
### with the i-th grid and the i-th parameter list
fine_true_factors <- pmap(
  .l = list(true_functions, fine_grid_list, true_hypers),
  .f = run_factors
)

fine_true_factors
## $g1
##  [1]  1.000000e+00  9.510565e-01  8.090170e-01  5.877853e-01  3.090170e-01
##  [6] -1.836970e-16 -3.090170e-01 -5.877853e-01 -8.090170e-01 -9.510565e-01
## [11] -1.000000e+00 -9.510565e-01 -8.090170e-01 -5.877853e-01 -3.090170e-01
## [16]  6.123234e-17  3.090170e-01  5.877853e-01  8.090170e-01  9.510565e-01
## [21]  1.000000e+00  9.510565e-01  8.090170e-01  5.877853e-01  3.090170e-01
## [26]  6.123234e-17 -3.090170e-01 -5.877853e-01 -8.090170e-01 -9.510565e-01
## [31] -1.000000e+00 -9.510565e-01 -8.090170e-01 -5.877853e-01 -3.090170e-01
## [36] -1.836970e-16  3.090170e-01  5.877853e-01  8.090170e-01  9.510565e-01
## [41]  1.000000e+00
## 
## $g2
##  [1]  1.000000e+00  9.510565e-01  8.090170e-01  5.877853e-01  3.090170e-01
##  [6] -1.836970e-16 -3.090170e-01 -5.877853e-01 -8.090170e-01 -9.510565e-01
## [11] -1.000000e+00 -9.510565e-01 -8.090170e-01 -5.877853e-01 -3.090170e-01
## [16]  6.123234e-17  3.090170e-01  5.877853e-01  8.090170e-01  9.510565e-01
## [21]  1.000000e+00  9.510565e-01  8.090170e-01  5.877853e-01  3.090170e-01
## [26]  6.123234e-17 -3.090170e-01 -5.877853e-01 -8.090170e-01 -9.510565e-01
## [31] -1.000000e+00 -9.510565e-01 -8.090170e-01 -5.877853e-01 -3.090170e-01
## [36] -1.836970e-16  3.090170e-01  5.877853e-01  8.090170e-01  9.510565e-01
## [41]  1.000000e+00
### stack the (x, g) pairs of both factors into one long table, labelled by
### factor index, and draw each factor in its own panel
list(fine_grid_list, fine_true_factors, 1:2) %>%
  pmap_dfr(function(x, g, glabel) {
    tibble::tibble(x = x, g = g, factor_name = glabel)
  }) %>%
  ggplot(aes(x = x, y = g)) +
  geom_line(aes(group = factor_name), size = 1.15) +
  facet_wrap(~factor_name) +
  theme_bw()

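Above we have defined 2 univariate smooth functions, one per covariate:

\[ g_1(x_1) = a_0 + a_1 \cos(a_2 \pi x_1) \\ g_2(x_2) = a_0 + a_1 \cos(a_2 \pi x_2) \]

The true latent function is the product of the two:

\[ f(x_1, x_2) = g_1(x_1) \, g_2(x_2) \]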

### one table per factor with columns x<k>_id (row index), x<k> (grid value)
### and g<k> (factor value), where <k> is the factor index
fine_true_latent_dfs <- pmap(
  list(fine_grid_list, fine_true_factors, 1:2),
  function(x, g, glabel) {
    factor_df <- tibble::tibble(x = x, g = g)
    names(factor_df) <- sprintf(c("x%d", "g%d"), glabel)
    tibble::rowid_to_column(factor_df, sprintf("x%d_id", glabel))
  }
)

fine_true_latent_dfs
## $x1
## # A tibble: 41 x 3
##    x1_id    x1        g1
##    <int> <dbl>     <dbl>
##  1     1 -2     1.00e+ 0
##  2     2 -1.9   9.51e- 1
##  3     3 -1.8   8.09e- 1
##  4     4 -1.7   5.88e- 1
##  5     5 -1.6   3.09e- 1
##  6     6 -1.5  -1.84e-16
##  7     7 -1.4  -3.09e- 1
##  8     8 -1.30 -5.88e- 1
##  9     9 -1.2  -8.09e- 1
## 10    10 -1.1  -9.51e- 1
## # … with 31 more rows
## 
## $x2
## # A tibble: 41 x 3
##    x2_id    x2        g2
##    <int> <dbl>     <dbl>
##  1     1 -2     1.00e+ 0
##  2     2 -1.9   9.51e- 1
##  3     3 -1.8   8.09e- 1
##  4     4 -1.7   5.88e- 1
##  5     5 -1.6   3.09e- 1
##  6     6 -1.5  -1.84e-16
##  7     7 -1.4  -3.09e- 1
##  8     8 -1.30 -5.88e- 1
##  9     9 -1.2  -8.09e- 1
## 10    10 -1.1  -9.51e- 1
## # … with 31 more rows
### create the full tensor-product grid of the two covariates and attach the
### per-factor ids and true factor values to every (x1, x2) combination
fine_latent_tensor <- expand.grid(fine_grid_list,
                                  KEEP.OUT.ATTRS = FALSE,
                                  stringsAsFactors = FALSE) %>% 
  # expand.grid() already returns a data.frame; tibble::as_tibble() replaces
  # the deprecated dplyr::tbl_df()
  tibble::as_tibble() %>% 
  left_join(fine_true_latent_dfs[[1]], by = "x1") %>% 
  left_join(fine_true_latent_dfs[[2]], by = "x2") %>% 
  # put the id columns first, then the covariates, then the factor values
  select(ends_with("_id"), x1:x2, g1, g2) %>% 
  tibble::rowid_to_column("fine_id")
## Warning: `tbl_df()` is deprecated as of dplyr 1.0.0.
## Please use `tibble::as_tibble()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
fine_latent_tensor
## # A tibble: 1,681 x 7
##    fine_id x1_id x2_id    x1    x2        g1    g2
##      <int> <int> <int> <dbl> <dbl>     <dbl> <dbl>
##  1       1     1     1 -2       -2  1.00e+ 0     1
##  2       2     2     1 -1.9     -2  9.51e- 1     1
##  3       3     3     1 -1.8     -2  8.09e- 1     1
##  4       4     4     1 -1.7     -2  5.88e- 1     1
##  5       5     5     1 -1.6     -2  3.09e- 1     1
##  6       6     6     1 -1.5     -2 -1.84e-16     1
##  7       7     7     1 -1.4     -2 -3.09e- 1     1
##  8       8     8     1 -1.30    -2 -5.88e- 1     1
##  9       9     9     1 -1.2     -2 -8.09e- 1     1
## 10      10    10     1 -1.1     -2 -9.51e- 1     1
## # … with 1,671 more rows
### visualize the product of the two factors, g1 * g2, over the fine grid
ggplot(fine_latent_tensor, aes(x = x1, y = x2)) +
  geom_raster(aes(fill = g1 * g2)) +
  coord_equal() +
  scale_fill_viridis_c() +
  theme_bw()

### visualize the product of the 2 factors as a function of x1,
### with one panel per x2 value
ggplot(
  mutate(fine_latent_tensor, go = g1 * g2),
  aes(x = x1, y = go)
) +
  geom_line() +
  facet_wrap(~x2, labeller = "label_both") +
  theme_bw() +
  theme(axis.text = element_blank())

### generate the noisy observations
sd_noise <- 0.2 # standard deviation of the additive Gaussian noise

### a single seed, set immediately before the random draw, makes the
### simulated noise reproducible
set.seed(813123)
fine_df <- fine_latent_tensor %>% 
  mutate(go = g1 * g2, # true latent function: product of the two factors
         y = rnorm(n = n(), mean = go, sd = sd_noise))

fine_df
## # A tibble: 1,681 x 9
##    fine_id x1_id x2_id    x1    x2        g1    g2        go      y
##      <int> <int> <int> <dbl> <dbl>     <dbl> <dbl>     <dbl>  <dbl>
##  1       1     1     1 -2       -2  1.00e+ 0     1  1.00e+ 0  0.987
##  2       2     2     1 -1.9     -2  9.51e- 1     1  9.51e- 1  1.20 
##  3       3     3     1 -1.8     -2  8.09e- 1     1  8.09e- 1  1.25 
##  4       4     4     1 -1.7     -2  5.88e- 1     1  5.88e- 1  0.693
##  5       5     5     1 -1.6     -2  3.09e- 1     1  3.09e- 1  0.430
##  6       6     6     1 -1.5     -2 -1.84e-16     1 -1.84e-16 -0.174
##  7       7     7     1 -1.4     -2 -3.09e- 1     1 -3.09e- 1 -0.563
##  8       8     8     1 -1.30    -2 -5.88e- 1     1 -5.88e- 1 -0.109
##  9       9     9     1 -1.2     -2 -8.09e- 1     1 -8.09e- 1 -1.03 
## 10      10    10     1 -1.1     -2 -9.51e- 1     1 -9.51e- 1 -1.22 
## # … with 1,671 more rows
### look at the fine grid noisy observations as a function of x1,
### with one panel per x2 value
ggplot(fine_df, aes(x = x1, y = y)) +
  geom_point() +
  facet_wrap(~x2, labeller = "label_both") +
  theme_bw() +
  theme(axis.text = element_blank())

Look at the noisy response y over the fine grid.

### raster map of the noisy response over the (x1, x2) fine grid
ggplot(fine_df, aes(x = x1, y = x2)) +
  geom_raster(aes(fill = y)) +
  scale_fill_viridis_c() +
  coord_equal() +
  theme_bw()

### work with a coarse grid instead of all of the points in the fine grid;
### each axis has num_coarse_int + 1 equally spaced points from -2 to 2
num_coarse_int <- 20

coarse_grid_list <- setNames(
  rep(list(seq(from = -2, to = 2, length.out = num_coarse_int + 1)), times = 2),
  c("x1", "x2")
)
coarse_grid_list
## $x1
##  [1] -2.0 -1.8 -1.6 -1.4 -1.2 -1.0 -0.8 -0.6 -0.4 -0.2  0.0  0.2  0.4  0.6  0.8
## [16]  1.0  1.2  1.4  1.6  1.8  2.0
## 
## $x2
##  [1] -2.0 -1.8 -1.6 -1.4 -1.2 -1.0 -0.8 -0.6 -0.4 -0.2  0.0  0.2  0.4  0.6  0.8
## [16]  1.0  1.2  1.4  1.6  1.8  2.0
### all (x1, x2) combinations of the coarse axes, as a tibble
coarse_grid <- expand.grid(coarse_grid_list,
                           KEEP.OUT.ATTRS = FALSE,
                           stringsAsFactors = FALSE) %>% 
  # expand.grid() already returns a data.frame; tibble::as_tibble() replaces
  # the deprecated dplyr::tbl_df()
  tibble::as_tibble()
coarse_grid
## # A tibble: 441 x 2
##        x1    x2
##     <dbl> <dbl>
##  1 -2        -2
##  2 -1.8      -2
##  3 -1.6      -2
##  4 -1.4      -2
##  5 -1.2      -2
##  6 -1        -2
##  7 -0.800    -2
##  8 -0.600    -2
##  9 -0.400    -2
## 10 -0.200    -2
## # … with 431 more rows
### training set: the right join keeps every coarse-grid row and pulls in the
### fine-grid columns wherever both x1 and x2 match
train_df <- right_join(fine_df, coarse_grid, by = c("x1", "x2"))

### number of training rows per x1 value
count(train_df, x1)
## # A tibble: 21 x 2
##        x1     n
##     <dbl> <int>
##  1 -2        21
##  2 -1.8      21
##  3 -1.6      21
##  4 -1.4      21
##  5 -1.2      21
##  6 -1        21
##  7 -0.800    21
##  8 -0.600    21
##  9 -0.400    21
## 10 -0.200    21
## # … with 11 more rows
count(train_df, x2)
## # A tibble: 21 x 2
##        x2     n
##     <dbl> <int>
##  1 -2        21
##  2 -1.8      21
##  3 -1.6      21
##  4 -1.4      21
##  5 -1.2      21
##  6 -1        21
##  7 -0.800    21
##  8 -0.600    21
##  9 -0.400    21
## 10 -0.200    21
## # … with 11 more rows
### compare the true latent function (line) with the noisy training
### observations (red points) as a function of x1, one panel per x2 value
ggplot(train_df, aes(x = x1, y = go)) +
  geom_line() +
  geom_point(aes(y = y), color = "red") +
  facet_wrap(~ x2, labeller = "label_both") +
  theme_bw() +
  theme(axis.text = element_blank())

ti() vs te()

te() builds a full tensor product smooth, which includes the marginal (main-effect) terms. ti() builds only the pure interaction part of the tensor product, with the marginal main effects excluded.

Let's start with the ti() fit.

Fit two models

model_1 = ti(x1, x2)

model_2 = ti(x1) + ti(x2) + ti(x1,x2)

### ti() fit with separate marginal terms for x1 and x2 plus an interaction
### term that uses a basis dimension of 20 in each direction; fitted by ML
ti_model_marginal_interaction <- gam(
  y ~ ti(x1) + ti(x2) + ti(x1, x2, k = c(20, 20)),
  data = train_df,
  method = "ML"
)

summary(ti_model_marginal_interaction)
## 
## Family: gaussian 
## Link function: identity 
## 
## Formula:
## y ~ ti(x1) + ti(x2) + ti(x1, x2, k = c(20, 20))
## 
## Parametric coefficients:
##             Estimate Std. Error t value Pr(>|t|)
## (Intercept) 0.004450   0.008348   0.533    0.594
## 
## Approximate significance of smooth terms:
##               edf  Ref.df      F p-value    
## ti(x1)      3.666   3.933  5.394 0.00105 ** 
## ti(x2)      2.156   2.630  1.727 0.15634    
## ti(x1,x2) 157.164 211.687 19.634 < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.905   Deviance explained =   94%
## -ML = 39.814  Scale est. = 0.030732  n = 441
### additive fit with one univariate smooth per covariate and no interaction
### NOTE(review): this fit uses REML while the tensor fits use ML - confirm
### that the difference is intended
gams_smooth_model <- gam(
  y ~ s(x1) + s(x2),
  data = train_df,
  method = "REML"
)

summary(gams_smooth_model)
## 
## Family: gaussian 
## Link function: identity 
## 
## Formula:
## y ~ s(x1) + s(x2)
## 
## Parametric coefficients:
##             Estimate Std. Error t value Pr(>|t|)
## (Intercept)  0.00445    0.02716   0.164     0.87
## 
## Approximate significance of smooth terms:
##       edf Ref.df     F p-value
## s(x1)   1  1.000 0.330   0.566
## s(x2)   1  1.001 0.135   0.714
## 
## R-sq.(adj) =  -0.0035   Deviance explained = 0.106%
## -REML = 384.65  Scale est. = 0.32524   n = 441
### te() fit: a single tensor product smooth of x1 and x2 with a basis
### dimension of 20 in each direction; fitted by ML
te_model_interaction <- gam(
  y ~ te(x1, x2, k = c(20, 20)),
  data = train_df,
  method = "ML"
)

summary(te_model_interaction)
## 
## Family: gaussian 
## Link function: identity 
## 
## Formula:
## y ~ te(x1, x2, k = c(20, 20))
## 
## Parametric coefficients:
##             Estimate Std. Error t value Pr(>|t|)
## (Intercept) 0.004450   0.008252   0.539     0.59
## 
## Approximate significance of smooth terms:
##           edf Ref.df     F p-value    
## te(x1,x2) 176  234.3 18.29  <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## R-sq.(adj) =  0.907   Deviance explained = 94.4%
## -ML = 59.416  Scale est. = 0.030028  n = 441
### vis.gam() view of the fitted te() model with its default settings
vis.gam(te_model_interaction)

### plot of the smooth term with residuals = TRUE
plot(
  te_model_interaction,
  residuals = TRUE,
  pch = 2,
  las = 2
)

### contour view of the fitted model over x1 and x2 with the "topo" colours
vis.gam(
  te_model_interaction,
  view = c("x1", "x2"),
  plot.type = "contour",
  color = "topo"
)

library(mgcViz)
## Loading required package: qgam
## Loading required package: rgl
## Registered S3 method overwritten by 'GGally':
##   method from   
##   +.gg   ggplot2
## Registered S3 method overwritten by 'mgcViz':
##   method from  
##   +.gg   GGally
## 
## Attaching package: 'mgcViz'
## The following objects are masked from 'package:stats':
## 
##     qqline, qqnorm, qqplot
# convert the gam object with getViz() so the mgcViz plot layers can be used
te_model_interaction <- getViz(te_model_interaction)
# plot the first smooth term as a raster with contour lines on top
plot(sm(te_model_interaction, 1)) +
  l_fitRaster() +
  l_fitContour()

# convert the viz object back to a gam object
te_model_interaction <- getGam(te_model_interaction)
# type = "lpmatrix" requests the linear predictor matrix; no newdata is given
te_basis_matrix <- predict.gam(te_model_interaction, type = "lpmatrix")
#te_basis_matrix
### plot a thinned subset of the te() basis columns (ids 1, 41, 81, ...)
### over the coarse grid
te_basis_matrix %>%
  as.data.frame() %>%
  # tibble::as_tibble() replaces the deprecated dplyr::tbl_df()
  tibble::as_tibble() %>%
  select(-'(Intercept)') %>%
  tibble::rowid_to_column() %>%
  # attach grid locations by row index
  # NOTE(review): assumes the rows of te_basis_matrix are in the same order
  # as the rows of coarse_grid - confirm
  left_join(coarse_grid %>% tibble::rowid_to_column(), by = 'rowid') %>%
  tidyr::gather(key = "key", value = "value", -rowid, -x1, -x2) %>%
  # split the column name at the "." into the smooth label and the basis id
  tidyr::separate(key, 
                  c("te_word", "te_id"),
                  sep = '\\.',
                  fill = "right",
                  remove = FALSE) %>%
  # factor levels follow order of first appearance so the panels keep the
  # column order instead of sorting the ids alphabetically
  mutate(id = factor(te_id, levels = unique(te_id))) %>%
  filter(id %in% seq(1, 400, by = 40)) %>%
  ggplot(mapping = aes(x = x1, y = x2)) +
  geom_raster(aes(fill = value)) +
  facet_wrap(~ id, labeller = "label_both") +
  scale_fill_viridis_b()

### plot every second te() basis column among ids 1 to 40 over the coarse grid
te_basis_matrix %>%
  as.data.frame() %>%
  # tibble::as_tibble() replaces the deprecated dplyr::tbl_df()
  tibble::as_tibble() %>%
  select(-'(Intercept)') %>%
  tibble::rowid_to_column() %>%
  # attach grid locations by row index
  # NOTE(review): assumes the rows of te_basis_matrix are in the same order
  # as the rows of coarse_grid - confirm
  left_join(coarse_grid %>% tibble::rowid_to_column(), by = 'rowid') %>%
  tidyr::gather(key = "key", value = "value", -rowid, -x1, -x2) %>%
  # split the column name at the "." into the smooth label and the basis id
  tidyr::separate(key, 
                  c("te_word", "te_id"),
                  sep = '\\.',
                  fill = "right",
                  remove = FALSE) %>%
  filter(te_id %in% seq(1, 40, by = 2)) %>%
  # factor levels follow order of first appearance so the panels keep the
  # column order
  mutate(id = factor(key, levels = unique(key))) %>%
  ggplot(mapping = aes(x = x1, y = x2)) +
  geom_raster(aes(fill = value)) +
  facet_wrap(~ id, labeller = "label_both") +
  scale_fill_viridis_b() +
  theme_bw()

### build a small te() basis (k = 4 per direction) directly with smoothCon()
### and keep the first returned smooth object
SmoothCon_te_basis <- smoothCon(te(x1, x2, k = c(4, 4)), data = train_df)[[1]]

### plot every column of its model matrix X over the coarse grid
SmoothCon_te_basis$X %>%
  as.data.frame() %>%
  # tibble::as_tibble() replaces the deprecated dplyr::tbl_df()
  tibble::as_tibble() %>% 
  tibble::rowid_to_column() %>% 
  # attach grid locations by row index
  # NOTE(review): assumes the rows of X are in the same order as the rows of
  # coarse_grid - confirm
  left_join(coarse_grid %>% tibble::rowid_to_column(),
            by = "rowid") %>%
  tidyr::gather(key = "key", value = "value", -rowid, -x1, -x2) %>%
  # factor levels follow order of first appearance so the panels keep the
  # column order
  mutate(id = factor(key, levels = unique(key))) %>%
  ggplot(mapping = aes(x = x1, y = x2)) +
  geom_raster(mapping = aes(fill = value)) +
  facet_wrap(~ id, labeller = "label_both") +
  scale_fill_viridis_b() +
  theme_bw()

### predict the te() model on every fine-grid location, with standard errors
gam_te_pred_fine_df <- predict(
  te_model_interaction,
  fine_df[c("x1", "x2")],
  type = "link",
  se.fit = TRUE
)

### attach the prediction and a +/- 2 standard error band to the fine grid
fine_gam_tensor_ti_pred_matrix <- fine_df %>%
  mutate(
    pred_log_y = gam_te_pred_fine_df$fit,
    pred_log_y_lwr = gam_te_pred_fine_df$fit - 2 * gam_te_pred_fine_df$se.fit,
    pred_log_y_upr = gam_te_pred_fine_df$fit + 2 * gam_te_pred_fine_df$se.fit
  )

fine_gam_tensor_ti_pred_matrix
## # A tibble: 1,681 x 12
##    fine_id x1_id x2_id    x1    x2        g1    g2        go      y pred_log_y
##      <int> <int> <int> <dbl> <dbl>     <dbl> <dbl>     <dbl>  <dbl>      <dbl>
##  1       1     1     1 -2       -2  1.00e+ 0     1  1.00e+ 0  0.987     1.05  
##  2       2     2     1 -1.9     -2  9.51e- 1     1  9.51e- 1  1.20      1.12  
##  3       3     3     1 -1.8     -2  8.09e- 1     1  8.09e- 1  1.25      1.08  
##  4       4     4     1 -1.7     -2  5.88e- 1     1  5.88e- 1  0.693     0.850 
##  5       5     5     1 -1.6     -2  3.09e- 1     1  3.09e- 1  0.430     0.468 
##  6       6     6     1 -1.5     -2 -1.84e-16     1 -1.84e-16 -0.174     0.0150
##  7       7     7     1 -1.4     -2 -3.09e- 1     1 -3.09e- 1 -0.563    -0.436 
##  8       8     8     1 -1.30    -2 -5.88e- 1     1 -5.88e- 1 -0.109    -0.818 
##  9       9     9     1 -1.2     -2 -8.09e- 1     1 -8.09e- 1 -1.03     -1.10  
## 10      10    10     1 -1.1     -2 -9.51e- 1     1 -9.51e- 1 -1.22     -1.26  
## # … with 1,671 more rows, and 2 more variables: pred_log_y_lwr <dbl>,
## #   pred_log_y_upr <dbl>
### predictions against x1, one panel per x2 value, on a thinned subset of
### grid ids (1, 5, 9, ...); band = +/- 2 standard errors, red = noisy y
ggplot(
  filter(
    fine_gam_tensor_ti_pred_matrix,
    x1_id %in% seq(1, 41, by = 4),
    x2_id %in% seq(1, 41, by = 4)
  ),
  aes(x = x1)
) +
  geom_ribbon(
    aes(ymin = pred_log_y_lwr, ymax = pred_log_y_upr, group = x2),
    fill = "grey", alpha = 0.5
  ) +
  geom_line(aes(y = pred_log_y, group = x2), color = "black") +
  geom_point(aes(y = y), color = "red", size = 0.85) +
  facet_wrap(~ x2) +
  theme_bw()

### next look at x1 and x2 locations that were NOT in the training set
### (grid ids 2, 6, 10, ...); band = +/- 2 standard errors, red = noisy y
ggplot(
  filter(
    fine_gam_tensor_ti_pred_matrix,
    x1_id %in% seq(2, 41, by = 4),
    x2_id %in% seq(2, 41, by = 4)
  ),
  aes(x = x1)
) +
  geom_ribbon(
    aes(ymin = pred_log_y_lwr, ymax = pred_log_y_upr, group = x2),
    fill = "grey", alpha = 0.5
  ) +
  geom_line(aes(y = pred_log_y, group = x2), color = "blue") +
  geom_point(aes(y = y), color = "red", size = 0.85) +
  facet_wrap(~ x2, labeller = "label_both") +
  theme_bw()

### raster map of the te() model predictions over the full fine grid
ggplot(fine_gam_tensor_ti_pred_matrix, aes(x = x1, y = x2)) +
  geom_raster(aes(fill = pred_log_y)) +
  scale_fill_viridis_c() +
  coord_equal() +
  theme_bw()